from __future__ import annotations

import argparse
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import pandas as pd
import urllib.request


@dataclass(frozen=True)
class Indicator:
    """One World Bank indicator to download.

    Immutable (``frozen=True``), so instances are hashable and compare by value.
    """

    # World Bank indicator code that is placed in the API URL, e.g. "GE.EST".
    code: str
    # Human-readable description of what the indicator measures.
    name: str


def fetch_worldbank_indicator(code: str, per_page: int = 20000, timeout: int = 120) -> list[dict[str, Any]]:
    """Download every observation of a World Bank indicator for all countries.

    Queries the World Bank API v2 ``country/all/indicator/<code>`` endpoint and
    follows the ``pages`` count in the response metadata. Records are therefore
    not silently dropped when the result set is larger than ``per_page``.

    Args:
        code: World Bank indicator code, e.g. ``"GE.EST"``.
        per_page: Number of records requested per page.
        timeout: Socket timeout in seconds, applied to each HTTP request.

    Returns:
        The concatenated observation dicts from every page. The list is empty
        when the indicator has no data.

    Raises:
        ValueError: If a response is not shaped like ``[metadata, records]``
            (for example, an API error payload).
    """
    base_url = f"https://api.worldbank.org/v2/country/all/indicator/{code}?format=json&per_page={per_page}"
    records: list[dict[str, Any]] = []
    page = 1
    while True:
        with urllib.request.urlopen(f"{base_url}&page={page}", timeout=timeout) as resp:
            data = json.loads(resp.read().decode("utf-8"))
        # A well-formed reply is [metadata, records]. Anything else (for example
        # a one-element error payload) is rejected.
        if not isinstance(data, list) or len(data) < 2 or not isinstance(data[0], dict):
            raise ValueError(f"Unexpected World Bank response for {code}")
        meta, page_records = data[0], data[1]
        # The record slot may be null when there is no data. Treat that as an
        # empty page instead of returning None to callers.
        if page_records is None:
            page_records = []
        if not isinstance(page_records, list):
            raise ValueError(f"Unexpected World Bank response for {code}")
        records.extend(page_records)
        try:
            total_pages = int(meta.get("pages") or 1)
        except (TypeError, ValueError):
            total_pages = 1
        if page >= total_pages:
            return records
        page += 1


def to_baseline(
    records: list[dict[str, Any]],
    indicator_code: str,
    baseline_year: int,
    min_year: int,
) -> pd.DataFrame:
    """Reduce raw World Bank observations to one baseline value per ISO3 code.

    For each 3-character ``countryiso3code``, keeps the most recent non-missing
    numeric value whose year lies in ``[min_year, baseline_year]`` (inclusive).

    Args:
        records: Observation dicts. The ``countryiso3code``, ``date`` and
            ``value`` keys are read; ``None`` is treated as no records.
        indicator_code: Indicator code, used to name the output columns.
        baseline_year: Latest year eligible for the baseline.
        min_year: Earliest year eligible for the baseline.

    Returns:
        A DataFrame with columns ``iso3``, ``<indicator_code>`` (the value) and
        ``<indicator_code>__year`` (the year of that value). When nothing
        qualifies, the frame is empty but still has these columns, so callers
        can merge it on ``iso3``.
    """
    value_col = indicator_code
    year_col = f"{indicator_code}__year"
    out_cols = ["iso3", value_col, year_col]

    rows = []
    for r in records or []:
        iso3 = r.get("countryiso3code")
        # Skip records without a usable 3-character code (e.g. empty strings).
        if not isinstance(iso3, str) or len(iso3) != 3:
            continue
        try:
            y = int(r.get("date"))
        except (TypeError, ValueError):
            continue
        if min_year <= y <= baseline_year:
            rows.append({"iso3": iso3, "year": y, "value": r.get("value")})

    df = pd.DataFrame(rows, columns=["iso3", "year", "value"])
    df["value"] = pd.to_numeric(df["value"], errors="coerce")
    df = df.dropna(subset=["value"])
    if df.empty:
        # Keep the schema so an outer merge on "iso3" still works downstream.
        return pd.DataFrame(columns=out_cols)

    # After sorting by year within each code, the last row is the latest year.
    latest = df.sort_values(["iso3", "year"]).groupby("iso3", as_index=False).tail(1)
    latest = latest.rename(columns={"year": year_col, "value": value_col})
    return latest[out_cols].reset_index(drop=True)


def main() -> None:
    """Build a per-country World Bank baseline table and write it to CSV.

    For each indicator in the catalog below, the function:

    1. Downloads all observations.
    2. Keeps the latest non-missing value per ISO3 code within
       ``[--min-year, --baseline-year]``, via ``to_baseline``.
    3. Outer-joins the indicators on ``iso3``, so a country missing some
       indicators still appears with blanks in those columns.

    It writes the result with short snake_case column names plus a
    ``baseline_year_target`` column.
    """
    parser = argparse.ArgumentParser(description="Build a World Bank baseline CSV (one row per ISO3 code).")
    parser.add_argument("--baseline-year", type=int, default=2019)
    parser.add_argument("--min-year", type=int, default=2000)
    parser.add_argument(
        "--out",
        type=Path,
        default=None,
        help="Output CSV path (default: outputs/data/worldbank_baseline_<baseline-year>.csv)",
    )
    args = parser.parse_args()
    if args.min_year > args.baseline_year:
        parser.error("--min-year must not be later than --baseline-year")
    # Derive the default file name from the requested year, so a non-2019 run
    # does not write a file labelled 2019.
    out: Path = args.out if args.out is not None else Path(
        f"outputs/data/worldbank_baseline_{args.baseline_year}.csv"
    )

    # Single source of truth: each entry holds (indicator, value column, year column).
    catalog: list[tuple[Indicator, str, str]] = [
        (
            Indicator("SP.REG.DTHS.ZS", "Death registration completeness with cause-of-death (%)"),
            "death_reg_cod_pct",
            "death_reg_cod_year",
        ),
        (
            Indicator("GE.EST", "WGI Government Effectiveness (estimate)"),
            "wgi_gov_effectiveness",
            "wgi_gov_effectiveness_year",
        ),
        (
            Indicator("RL.EST", "WGI Rule of Law (estimate)"),
            "wgi_rule_of_law",
            "wgi_rule_of_law_year",
        ),
        (
            Indicator("CC.EST", "WGI Control of Corruption (estimate)"),
            "wgi_control_corruption",
            "wgi_control_corruption_year",
        ),
        (
            Indicator("VA.EST", "WGI Voice and Accountability (estimate)"),
            "wgi_voice_accountability",
            "wgi_voice_accountability_year",
        ),
        (
            Indicator("PV.EST", "WGI Political Stability (estimate)"),
            "wgi_political_stability",
            "wgi_political_stability_year",
        ),
        (
            Indicator("RQ.EST", "WGI Regulatory Quality (estimate)"),
            "wgi_regulatory_quality",
            "wgi_regulatory_quality_year",
        ),
    ]

    frames: list[pd.DataFrame] = []
    for ind, value_col, year_col in catalog:
        raw_year_col = f"{ind.code}__year"
        records = fetch_worldbank_indicator(ind.code)
        b = to_baseline(records, ind.code, baseline_year=args.baseline_year, min_year=args.min_year)
        if "iso3" not in b.columns:
            # A column-less frame (no qualifying records) cannot be merged on
            # "iso3", so substitute an empty frame with the expected schema.
            b = pd.DataFrame(columns=["iso3", ind.code, raw_year_col])
        frames.append(b.rename(columns={ind.code: value_col, raw_year_col: year_col}))

    if not frames:
        raise RuntimeError("No indicators processed.")

    merged = frames[0]
    for b in frames[1:]:
        # The outer join keeps countries that lack some indicators.
        merged = merged.merge(b, how="outer", on="iso3")

    # The outer join introduces NaN into the year columns, which would then be
    # written as floats (e.g. 2019.0). Store them as nullable integers instead.
    for _, _, year_col in catalog:
        merged[year_col] = pd.to_numeric(merged[year_col], errors="coerce").astype("Int64")

    merged["baseline_year_target"] = args.baseline_year

    out.parent.mkdir(parents=True, exist_ok=True)
    merged.to_csv(out, index=False, encoding="utf-8")
    print(f"Wrote {out} rows={len(merged):,} cols={len(merged.columns):,}")


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

